More PyMC

Lecture 20

Dr. Colin Rundel

Demo 1 - Logistic Regression





Based on PyMC Out-Of-Sample Predictions example

Data

           x1        x2  y
0   -3.207674  0.859021  0
1    0.128200  2.827588  0
2    1.481783 -0.116956  0
3    0.305238 -1.378604  0
4    1.727488 -0.926357  1
..        ...       ... ..
245 -2.182813  3.314672  0
246 -2.362568  2.078652  0
247  0.114571  2.249021  0
248  2.093975 -1.212528  1
249  1.241667 -2.363412  0

[250 rows x 3 columns]

Test-train split

import numpy as np
import patsy
from sklearn.model_selection import train_test_split

y, X = patsy.dmatrices("y ~ x1 * x2", data=df)

X_lab = X.design_info.column_names
y_lab = y.design_info.column_names
y = np.asarray(y).flatten()
X = np.asarray(X)

X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.7)
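train_test_split draws a random 70/30 split, so the rows assigned to each set vary run to run unless a seed is fixed. A minimal numpy sketch of the same shuffle-and-slice idea (simple_split is a hypothetical helper, not part of sklearn):

```python
import numpy as np

def simple_split(X, y, train_size=0.7, seed=0):
    """Minimal shuffle-split: permute row indices, then slice."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(y))
    n_train = int(round(train_size * len(y)))
    train, test = idx[:n_train], idx[n_train:]
    return X[train], X[test], y[train], y[test]

X = np.arange(20).reshape(10, 2)
y = np.arange(10)
X_tr, X_te, y_tr, y_te = simple_split(X, y)
print(X_tr.shape, X_te.shape)  # (7, 2) (3, 2)
```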

Model

with pm.Model(coords = {"coeffs": X_lab}) as model:
    # data containers
    X = pm.MutableData("X", X_train)
    y = pm.MutableData("y", y_train)
    # priors
    b = pm.Normal("b", mu=0, sigma=3, dims="coeffs")
    # linear model
    mu = X @ b
    # link function
    p = pm.Deterministic("p", pm.math.invlogit(mu))
    # likelihood
    pm.Bernoulli("obs", p=p, observed=y)
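The link function above maps the unbounded linear predictor mu into a probability in (0, 1). A minimal numpy sketch of the inverse logit, matching what pm.math.invlogit computes:

```python
import numpy as np

def invlogit(mu):
    """Inverse logit (sigmoid): maps the linear predictor to (0, 1)."""
    return 1 / (1 + np.exp(-mu))

mu = np.array([-2.0, 0.0, 2.0])
p = invlogit(mu)
print(p.round(3))  # [0.119 0.5   0.881]
```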

Visualizing models

pm.model_to_graphviz(model)

[Model graph: MutableData X and b ~ Normal (dims="coeffs") feed the Deterministic p; p feeds obs ~ Bernoulli, which is observed against MutableData y.]

Fitting

with model:
    post = pm.sample(progressbar=False, random_seed=1234)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [b]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1 seconds.
az.summary(post)
               mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
b[Intercept] -0.766  0.306  -1.380   -0.228      0.007    0.005    1976.0    2482.0    1.0
b[x1]         1.147  0.297   0.613    1.703      0.007    0.005    1653.0    2352.0    1.0
b[x2]        -1.315  0.301  -1.846   -0.722      0.007    0.005    1826.0    2279.0    1.0
b[x1:x2]      2.235  0.414   1.434    2.956      0.011    0.008    1534.0    1999.0    1.0
p[0]          1.000  0.000   1.000    1.000      0.000    0.000    1687.0    2141.0    1.0
...             ...    ...     ...      ...        ...      ...       ...       ...    ...
p[170]        0.575  0.101   0.384    0.759      0.002    0.001    3446.0    3016.0    1.0
p[171]        0.000  0.000   0.000    0.000      0.000    0.000    1383.0    1674.0    1.0
p[172]        0.465  0.064   0.348    0.585      0.001    0.001    3636.0    3218.0    1.0
p[173]        0.001  0.002   0.000    0.003      0.000    0.000    1332.0    1711.0    1.0
p[174]        0.796  0.071   0.668    0.923      0.001    0.001    2841.0    2664.0    1.0

[179 rows x 9 columns]

Trace plots

ax = az.plot_trace(post, var_names="b", compact=False)
plt.show()

Posterior plots

# ref_val marks the true coefficient values used when simulating the data
ax = az.plot_posterior(
    post, var_names=["b"], ref_val=[intercept, beta_x1, beta_x2, beta_interaction], figsize=(15, 4)
)
plt.show()

Out-of-sample predictions

post
arviz.InferenceData
    • <xarray.Dataset>
      Dimensions:  (chain: 4, draw: 1000, coeffs: 4, p_dim_0: 175)
      Coordinates:
        * chain    (chain) int64 0 1 2 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
        * coeffs   (coeffs) <U9 'Intercept' 'x1' 'x2' 'x1:x2'
        * p_dim_0  (p_dim_0) int64 0 1 2 3 4 5 6 7 ... 167 168 169 170 171 172 173 174
      Data variables:
          b        (chain, draw, coeffs) float64 -0.1456 0.8982 -1.347 ... -1.17 2.172
          p        (chain, draw, p_dim_0) float64 1.0 2.057e-05 ... 0.0003585 0.7274
      Attributes:
          created_at:                 2023-03-27T21:24:00.194281
          arviz_version:              0.15.1
          inference_library:          pymc
          inference_library_version:  5.1.2
          sampling_time:              0.8028221130371094
          tuning_steps:               1000

    • <xarray.Dataset>
      Dimensions:                (chain: 4, draw: 1000)
      Coordinates:
        * chain                  (chain) int64 0 1 2 3
        * draw                   (draw) int64 0 1 2 3 4 5 ... 994 995 996 997 998 999
      Data variables: (12/17)
          energy_error           (chain, draw) float64 -0.2301 -0.1901 ... 0.05898
          perf_counter_diff      (chain, draw) float64 0.0001709 ... 0.0001736
          step_size              (chain, draw) float64 0.5446 0.5446 ... 0.4022 0.4022
          step_size_bar          (chain, draw) float64 0.511 0.511 ... 0.4849 0.4849
          tree_depth             (chain, draw) int64 2 2 2 2 2 3 3 2 ... 2 3 3 4 3 3 2
          diverging              (chain, draw) bool False False False ... False False
          ...                     ...
          n_steps                (chain, draw) float64 3.0 3.0 3.0 3.0 ... 7.0 7.0 3.0
          acceptance_rate        (chain, draw) float64 0.9743 1.0 ... 0.9958 0.6302
          reached_max_treedepth  (chain, draw) bool False False False ... False False
          lp                     (chain, draw) float64 -59.99 -58.18 ... -57.4 -57.6
          smallest_eigval        (chain, draw) float64 nan nan nan nan ... nan nan nan
          perf_counter_start     (chain, draw) float64 1.663e+05 ... 1.663e+05
      Attributes:
          created_at:                 2023-03-27T21:24:00.200235
          arviz_version:              0.15.1
          inference_library:          pymc
          inference_library_version:  5.1.2
          sampling_time:              0.8028221130371094
          tuning_steps:               1000

    • <xarray.Dataset>
      Dimensions:    (obs_dim_0: 175)
      Coordinates:
        * obs_dim_0  (obs_dim_0) int64 0 1 2 3 4 5 6 7 ... 168 169 170 171 172 173 174
      Data variables:
          obs        (obs_dim_0) int64 1 0 1 1 1 1 0 0 0 0 0 ... 0 0 1 1 0 0 1 0 0 0 1
      Attributes:
          created_at:                 2023-03-27T21:24:00.202353
          arviz_version:              0.15.1
          inference_library:          pymc
          inference_library_version:  5.1.2

    • <xarray.Dataset>
      Dimensions:  (X_dim_0: 175, X_dim_1: 4, y_dim_0: 175)
      Coordinates:
        * X_dim_0  (X_dim_0) int64 0 1 2 3 4 5 6 7 ... 167 168 169 170 171 172 173 174
        * X_dim_1  (X_dim_1) int64 0 1 2 3
        * y_dim_0  (y_dim_0) int64 0 1 2 3 4 5 6 7 ... 167 168 169 170 171 172 173 174
      Data variables:
          X        (X_dim_0, X_dim_1) float64 1.0 -3.422 -3.397 ... -1.075 0.6637
          y        (y_dim_0) float64 1.0 0.0 1.0 1.0 1.0 1.0 ... 1.0 0.0 0.0 0.0 1.0
      Attributes:
          created_at:                 2023-03-27T21:24:00.202812
          arviz_version:              0.15.1
          inference_library:          pymc
          inference_library_version:  5.1.2

with model:
  pm.set_data({"X": X_test, "y": y_test})
  post = pm.sample_posterior_predictive(
    post, progressbar=False, var_names=["obs", "p"],
    extend_inferencedata = True
  )
Sampling: [obs]
post
arviz.InferenceData
    • <xarray.Dataset>
      Dimensions:  (chain: 4, draw: 1000, coeffs: 4, p_dim_0: 175)
      Coordinates:
        * chain    (chain) int64 0 1 2 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
        * coeffs   (coeffs) <U9 'Intercept' 'x1' 'x2' 'x1:x2'
        * p_dim_0  (p_dim_0) int64 0 1 2 3 4 5 6 7 ... 167 168 169 170 171 172 173 174
      Data variables:
          b        (chain, draw, coeffs) float64 -0.1456 0.8982 -1.347 ... -1.17 2.172
          p        (chain, draw, p_dim_0) float64 1.0 2.057e-05 ... 0.0003585 0.7274
      Attributes:
          created_at:                 2023-03-27T21:24:00.194281
          arviz_version:              0.15.1
          inference_library:          pymc
          inference_library_version:  5.1.2
          sampling_time:              0.8028221130371094
          tuning_steps:               1000

    • <xarray.Dataset>
      Dimensions:    (chain: 4, draw: 1000, obs_dim_2: 75, p_dim_2: 75)
      Coordinates:
        * chain      (chain) int64 0 1 2 3
        * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999
        * obs_dim_2  (obs_dim_2) int64 0 1 2 3 4 5 6 7 8 ... 67 68 69 70 71 72 73 74
        * p_dim_2    (p_dim_2) int64 0 1 2 3 4 5 6 7 8 ... 66 67 68 69 70 71 72 73 74
      Data variables:
          obs        (chain, draw, obs_dim_2) int64 0 1 0 0 1 1 1 1 ... 0 0 1 0 1 1 0
          p          (chain, draw, p_dim_2) float64 0.4425 0.9079 ... 1.0 0.00572
      Attributes:
          created_at:                 2023-03-27T21:24:04.675786
          arviz_version:              0.15.1
          inference_library:          pymc
          inference_library_version:  5.1.2

    • <xarray.Dataset>
      Dimensions:                (chain: 4, draw: 1000)
      Coordinates:
        * chain                  (chain) int64 0 1 2 3
        * draw                   (draw) int64 0 1 2 3 4 5 ... 994 995 996 997 998 999
      Data variables: (12/17)
          energy_error           (chain, draw) float64 -0.2301 -0.1901 ... 0.05898
          perf_counter_diff      (chain, draw) float64 0.0001709 ... 0.0001736
          step_size              (chain, draw) float64 0.5446 0.5446 ... 0.4022 0.4022
          step_size_bar          (chain, draw) float64 0.511 0.511 ... 0.4849 0.4849
          tree_depth             (chain, draw) int64 2 2 2 2 2 3 3 2 ... 2 3 3 4 3 3 2
          diverging              (chain, draw) bool False False False ... False False
          ...                     ...
          n_steps                (chain, draw) float64 3.0 3.0 3.0 3.0 ... 7.0 7.0 3.0
          acceptance_rate        (chain, draw) float64 0.9743 1.0 ... 0.9958 0.6302
          reached_max_treedepth  (chain, draw) bool False False False ... False False
          lp                     (chain, draw) float64 -59.99 -58.18 ... -57.4 -57.6
          smallest_eigval        (chain, draw) float64 nan nan nan nan ... nan nan nan
          perf_counter_start     (chain, draw) float64 1.663e+05 ... 1.663e+05
      Attributes:
          created_at:                 2023-03-27T21:24:00.200235
          arviz_version:              0.15.1
          inference_library:          pymc
          inference_library_version:  5.1.2
          sampling_time:              0.8028221130371094
          tuning_steps:               1000

    • <xarray.Dataset>
      Dimensions:    (obs_dim_0: 175)
      Coordinates:
        * obs_dim_0  (obs_dim_0) int64 0 1 2 3 4 5 6 7 ... 168 169 170 171 172 173 174
      Data variables:
          obs        (obs_dim_0) int64 1 0 1 1 1 1 0 0 0 0 0 ... 0 0 1 1 0 0 1 0 0 0 1
      Attributes:
          created_at:                 2023-03-27T21:24:00.202353
          arviz_version:              0.15.1
          inference_library:          pymc
          inference_library_version:  5.1.2

    • <xarray.Dataset>
      Dimensions:  (X_dim_0: 175, X_dim_1: 4, y_dim_0: 175)
      Coordinates:
        * X_dim_0  (X_dim_0) int64 0 1 2 3 4 5 6 7 ... 167 168 169 170 171 172 173 174
        * X_dim_1  (X_dim_1) int64 0 1 2 3
        * y_dim_0  (y_dim_0) int64 0 1 2 3 4 5 6 7 ... 167 168 169 170 171 172 173 174
      Data variables:
          X        (X_dim_0, X_dim_1) float64 1.0 -3.422 -3.397 ... -1.075 0.6637
          y        (y_dim_0) float64 1.0 0.0 1.0 1.0 1.0 1.0 ... 1.0 0.0 0.0 0.0 1.0
      Attributes:
          created_at:                 2023-03-27T21:24:00.202812
          arviz_version:              0.15.1
          inference_library:          pymc
          inference_library_version:  5.1.2

Posterior predictive summary

az.summary(
  post.posterior_predictive  
)
         mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
obs[0]  0.411  0.492   0.000    1.000      0.008    0.006    3862.0    3862.0    1.0
obs[1]  0.796  0.403   0.000    1.000      0.007    0.005    3518.0    3518.0    1.0
obs[2]  0.000  0.000   0.000    0.000      0.000    0.000    4000.0    4000.0    NaN
obs[3]  0.012  0.110   0.000    0.000      0.002    0.001    3948.0    3948.0    1.0
obs[4]  0.980  0.142   1.000    1.000      0.002    0.002    4085.0    4000.0    1.0
...       ...    ...     ...      ...        ...      ...       ...       ...    ...
p[70]   0.883  0.092   0.711    0.995      0.002    0.001    3797.0    2989.0    1.0
p[71]   0.523  0.065   0.407    0.649      0.001    0.001    3292.0    2829.0    1.0
p[72]   1.000  0.001   0.999    1.000      0.000    0.000    1720.0    2097.0    1.0
p[73]   1.000  0.000   1.000    1.000      0.000    0.000    1647.0    1928.0    1.0
p[74]   0.008  0.011   0.000    0.026      0.000    0.000    1663.0    2199.0    1.0

[150 rows x 9 columns]

/opt/homebrew/lib/python3.10/site-packages/arviz/stats/diagnostics.py:592: RuntimeWarning: invalid value encountered in
   double_scalars
  (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)

Evaluation

post.posterior["p"].shape
(4, 1000, 175)
post.posterior_predictive["p"].shape
(4, 1000, 75)
p_train = post.posterior["p"].mean(dim=["chain", "draw"])
p_test  = post.posterior_predictive["p"].mean(dim=["chain", "draw"])
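Both reductions above collapse the (chain, draw, obs) posterior arrays down to one mean probability per observation. A numpy sketch of the same reduction on hypothetical draws:

```python
import numpy as np

# Hypothetical posterior draws shaped (chain, draw, obs),
# standing in for post.posterior["p"]
rng = np.random.default_rng(1)
p_draws = rng.uniform(size=(4, 1000, 5))

# Averaging over the first two axes matches .mean(dim=["chain", "draw"])
p_mean = p_draws.mean(axis=(0, 1))
print(p_mean.shape)  # (5,)
```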
from sklearn.metrics import RocCurveDisplay, accuracy_score, auc, roc_curve

# Test data
fpr_test, tpr_test, thd_test = roc_curve(y_true=y_test, y_score=p_test)
auc_test = auc(fpr_test, tpr_test); auc_test
0.9598278335724533

# Training data
fpr_train, tpr_train, thd_train = roc_curve(y_true=y_train, y_score=p_train)
auc_train = auc(fpr_train, tpr_train); auc_train
0.9501569858712715
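AUC is threshold-free; to get hard class predictions (e.g. for accuracy_score, imported above), the usual choice is to threshold the posterior mean probability at 0.5. A small sketch with hypothetical values:

```python
import numpy as np

# Hypothetical posterior mean probabilities and true labels
p_hat = np.array([0.9, 0.2, 0.6, 0.4, 0.8])
y_true = np.array([1, 0, 1, 1, 1])

# Threshold at 0.5 to turn probabilities into class predictions
y_pred = (p_hat > 0.5).astype(int)
accuracy = (y_pred == y_true).mean()
print(accuracy)  # 0.8
```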

ROC Curves

fig, ax = plt.subplots()
roc = RocCurveDisplay(fpr=fpr_test, tpr=tpr_test).plot(ax=ax, label="test")
roc = RocCurveDisplay(fpr=fpr_train, tpr=tpr_train).plot(ax=ax, color="k", label="train")
plt.show()

Demo 2 - Gaussian Process

Data

d = pd.read_csv("data/gp.csv"); d
            x         y
0    0.002189  1.070772
1    0.006209  0.863336
2    0.006764  0.846165
3    0.009349  0.916748
4    0.012407  1.258828
..        ...       ...
245  0.982005 -0.540678
246  0.983324 -0.751002
247  0.992081 -0.510908
248  0.993567 -0.537508
249  0.994654 -0.621642

[250 rows x 2 columns]
n = d.shape[0]
D = np.array([np.abs(xi - d.x) for xi in d.x])
I = np.eye(n)
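The list comprehension builds an n × n matrix of pairwise distances; the same thing can be written with broadcasting, and the model's covariance is then nugget * I plus a squared-exponential term. A numpy sketch (note that ls multiplies the distance in this parameterization, so it acts as an inverse length scale):

```python
import numpy as np

x = np.linspace(0, 1, 5)

# Broadcast equivalent of the list-comprehension distance matrix
D = np.abs(x[:, None] - x[None, :])

# Squared-exponential covariance as parameterized in the model
nugget, sigma2, ls = 0.5, 4.0, 11.0
Sigma = nugget * np.eye(len(x)) + sigma2 * np.exp(-0.5 * D**2 * ls**2)
print(Sigma.shape)  # (5, 5)
```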

pymc model

with pm.Model() as gp:
  nugget = pm.HalfCauchy("nugget", beta=5)
  sigma2 = pm.HalfCauchy("sigma2", beta=5)
  ls     = pm.HalfCauchy("ls",     beta=5)
  
  Sigma = I * nugget + sigma2 * np.exp(-0.5 * D**2 * ls**2)
  
  pm.MvNormal(
    "y", mu=np.zeros(n),
    cov=Sigma, observed=d.y
  )
with gp:
  post_nuts = pm.sample(chains=2)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 4 jobs)
NUTS: [nugget, sigma2, ls]
Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 878 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics

Posterior summary

az.summary(post_nuts)
          mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
nugget   0.474  0.126   0.270    0.718      0.004    0.003    1174.0    1216.0    1.0
sigma2   4.180  2.668   1.184    8.328      0.086    0.061    1125.0    1047.0    1.0
ls      11.251  2.510   6.535   15.748      0.074    0.053    1078.0     982.0    1.0
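Under this parameterization the correlation between two points at distance h is exp(-0.5 h² ls²), so the posterior mean of ls implies a rough effective range: the distance at which correlation falls to some small cutoff (0.05 here is an arbitrary illustrative choice):

```python
import numpy as np

ls = 11.251  # posterior mean of ls from the summary above

# Solve exp(-0.5 * h**2 * ls**2) = 0.05 for h
h_eff = np.sqrt(2 * np.log(1 / 0.05)) / ls
print(round(h_eff, 3))  # ≈ 0.218
```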

Trace plots

ax = az.plot_trace(post_nuts)
plt.show()

Posterior Predictive

with gp:
  post_nuts = pm.sample_posterior_predictive(
    post_nuts, extend_inferencedata = True, progressbar=False
  )
Sampling: [y]
/opt/homebrew/lib/python3.10/site-packages/scipy/stats/_multivariate.py:753: RuntimeWarning: covariance is not positive-semidefinite.
  out = random_state.multivariate_normal(mean, cov, size)
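The "covariance is not positive-semidefinite" warning above usually reflects numerical round-off rather than a modeling error; a common remedy is to add a small jitter to the diagonal before factorizing or sampling. A numpy sketch (add_jitter is an illustrative helper, not a PyMC function):

```python
import numpy as np

def add_jitter(Sigma, jitter=1e-6):
    """Stabilize a nearly-PSD covariance by inflating its diagonal slightly."""
    return Sigma + jitter * np.eye(Sigma.shape[0])

# A covariance that is PSD in exact arithmetic but borderline numerically
x = np.linspace(0, 1, 50)
D = np.abs(x[:, None] - x[None, :])
Sigma = np.exp(-0.5 * (D / 0.1)**2)

# Cholesky can fail on the raw matrix; the jittered version factors cleanly
L = np.linalg.cholesky(add_jitter(Sigma))
print(L.shape)  # (50, 50)
```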